"""
Phase 2: Allocation Puzzle
==========================
Does controlling for demographics resolve the Gourinchas-Jeanne allocation
puzzle? (Net capital inflows are negatively correlated with productivity growth.)
"""

import pandas as pd
import numpy as np
from pathlib import Path
import sys
import warnings
warnings.filterwarnings('ignore')

# PROJECT_DIR is the parent of the directory containing this script;
# ROOT_DIR is one level above that and holds sibling projects.
PROJECT_DIR = Path(__file__).resolve().parent.parent
ROOT_DIR = PROJECT_DIR.parent
# Put the sibling "multilateral" project's src/ first on the import path so
# PanelGLS can be imported from its `model` module.
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Input directory for processed panels and output directory for result tables.
DATA = PROJECT_DIR / "data" / "processed"
OUT_TABLES = PROJECT_DIR / "output" / "tables"
# Import-time side effect: create the output directory if it does not exist.
OUT_TABLES.mkdir(parents=True, exist_ok=True)


def run_regression(df, y_var, x_vars, label, feature_names=None, min_obs=50):
    """Fit a PanelGLS of ``y_var`` on ``x_vars`` and return a flat results dict.

    Rows with a missing value in ``y_var``, any of ``x_vars``, ``iso3`` or
    ``year`` are dropped first; ``iso3`` and ``year`` are passed to
    ``PanelGLS.fit`` as the panel and time identifiers.

    Args:
        df: DataFrame holding the regression columns plus ``iso3`` and ``year``.
        y_var: Name of the dependent-variable column.
        x_vars: Regressor column names (list or tuple), in coefficient order.
        label: Model label; stored under ``'model'`` and used in printed output.
        feature_names: Optional display names for ``x_vars`` (same length and
            order), used as prefixes of the ``*_coef``/``*_se``/``*_p`` keys.
            Falls back to ``x_vars`` when omitted or empty.
        min_obs: Minimum number of complete rows required to fit the model.

    Returns:
        dict with ``model``, ``n_obs``, ``n_countries``, ``r_squared``, ``rho``
        and ``<name>_coef``/``<name>_se``/``<name>_p`` for each regressor, or
        ``None`` when the complete-case sample has fewer than ``min_obs`` rows.

    Raises:
        ValueError: if ``feature_names`` is given with a different length
            than ``x_vars``.
    """
    x_vars = list(x_vars)
    names = list(feature_names) if feature_names else x_vars
    # Validate up front: a short list would silently drop coefficients and a
    # long one would index past the end of the coefficient vector.
    if len(names) != len(x_vars):
        raise ValueError(
            f"{label}: feature_names has {len(names)} entries but x_vars has "
            f"{len(x_vars)}"
        )

    cols = [y_var, *x_vars, 'iso3', 'year']
    sub = df[cols].dropna()
    if len(sub) < min_obs:
        print(f"  {label}: insufficient obs ({len(sub)}), skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[y_var].values, sub[x_vars].values,
            sub['iso3'].values, sub['year'].values)

    result = {
        'model': label,
        'n_obs': gls.n_obs,
        'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
        'rho': gls.rho,
    }

    print(f"\n  {label} (N={gls.n_obs}, R²={gls.r_squared:.4f})")
    for i, name in enumerate(names):
        coef, se, p = gls.beta[i], gls.se[i], gls.pvalues[i]
        result[f'{name}_coef'] = coef
        result[f'{name}_se'] = se
        result[f'{name}_p'] = p
        # Reuse the module's significance-marker helper instead of
        # duplicating the p-value thresholds here.
        print(f"    {name:30s} {coef:8.4f} ({se:.4f}) {stars(p)}")

    return result


def stars(p):
    """Map a p-value to its conventional significance marker.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise '' (including
    for NaN, since every comparison with NaN is False).
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return marker
    return ''


def _run_country_models(df, controls):
    """Run Models 1-4: country-level regressions of ``ca_gdp``.

    M1 uses ``rgdp_growth`` plus ``controls``; M2 adds ``Z_1``..``Z_3``;
    M3 adds ``old_dep`` instead; M4a-c re-run M1 within terciles of
    ``old_dep``. M3 and M4 are skipped when the panel has no ``old_dep``.

    Returns:
        list of result dicts from ``run_regression`` (skipped models omitted).
    """
    results = []

    def add(result):
        # run_regression returns None when the sample is too small.
        if result is not None:
            results.append(result)

    # ── Model 1: Baseline puzzle (ca_gdp = β·growth + controls) ──
    print("\n--- Model 1: Baseline Allocation Puzzle ---")
    add(run_regression(df, 'ca_gdp', ['rgdp_growth'] + controls,
                       'M1: Baseline Puzzle'))

    # ── Model 2: + Demographics ──
    print("\n--- Model 2: + Demographics (Z₁, Z₂, Z₃) ---")
    add(run_regression(df, 'ca_gdp',
                       ['rgdp_growth', 'Z_1', 'Z_2', 'Z_3'] + controls,
                       'M2: + Demographics'))

    has_oadr = 'old_dep' in df.columns

    # ── Model 3: + OADR only ──
    print("\n--- Model 3: + OADR ---")
    if has_oadr:
        add(run_regression(df, 'ca_gdp', ['rgdp_growth', 'old_dep'] + controls,
                           'M3: + OADR'))

    # ── Model 4a-c: By demographic stage ──
    print("\n--- Model 4: By Demographic Stage ---")
    if has_oadr:
        # Exact terciles (1/3 and 2/3 quantiles) of old_dep; rows with a
        # missing old_dep fall into no stage.
        q_lo, q_hi = df['old_dep'].quantile([1 / 3, 2 / 3])
        stages = {
            'M4a: Early Transition': df['old_dep'] <= q_lo,
            'M4b: Mid Transition': (df['old_dep'] > q_lo) & (df['old_dep'] <= q_hi),
            'M4c: Late Transition': df['old_dep'] > q_hi,
        }
        for label, mask in stages.items():
            add(run_regression(df[mask], 'ca_gdp',
                               ['rgdp_growth'] + controls, label))

    return results


def _run_growth_residual(df, controls):
    """Run Model 5: ``ca_gdp`` on growth purged of the Z demographic factors.

    Step (a) fits PanelGLS of ``rgdp_growth`` on ``Z_1``..``Z_3`` and keeps its
    residual; step (b) regresses ``ca_gdp`` on that residual plus ``controls``.

    Returns:
        result dict from ``run_regression``, or ``None`` if skipped.
    """
    print("\n--- Model 5: Growth Residual (purged of demographics) ---")
    cols = ['rgdp_growth', 'Z_1', 'Z_2', 'Z_3', 'iso3', 'year', 'ca_gdp'] + controls
    # Explicit copy so adding growth_resid below writes to an independent frame.
    sub = df[cols].dropna().copy()
    if len(sub) <= 100:
        print(f"  M5: Growth Residual: insufficient obs ({len(sub)}), skipping")
        return None

    gls_aux = PanelGLS()
    gls_aux.fit(sub['rgdp_growth'].values,
                sub[['Z_1', 'Z_2', 'Z_3']].values,
                sub['iso3'].values, sub['year'].values)
    # NOTE(review): assumes PanelGLS.resid is in the same row order as the
    # arrays passed to fit() — confirm in model.PanelGLS.
    sub['growth_resid'] = gls_aux.resid

    return run_regression(sub, 'ca_gdp', ['growth_resid'] + controls,
                          'M5: Growth Residual')


def _run_r2_decomposition(df):
    """Run Model 6: growth-only, demographics-only and joint fits of ``ca_gdp``.

    All three specifications use one common sample. Otherwise each model's
    own dropna() would keep different rows and the R² values would not be
    comparable.

    Returns:
        list of result dicts from ``run_regression`` (skipped models omitted).
    """
    print("\n--- Model 6: R² Decomposition ---")
    demo = ['Z_1', 'Z_2', 'Z_3']
    common = df.dropna(subset=['ca_gdp', 'rgdp_growth', 'iso3', 'year'] + demo)
    specs = [
        (['rgdp_growth'], 'M6a: Growth Only'),
        (demo, 'M6b: Demographics Only'),
        (['rgdp_growth'] + demo, 'M6c: Both'),
    ]
    results = []
    for x_vars, label in specs:
        r = run_regression(common, 'ca_gdp', x_vars, label)
        if r is not None:
            results.append(r)
    return results


def _lookup_growth(growth_map, iso_codes, years):
    """Return growth_map[(iso, year)] for each pair, NaN where no entry exists."""
    return [growth_map.get(key, np.nan) for key in zip(iso_codes, years)]


def _fit_bilateral(sub_bp, x_vars):
    """Fit PanelGLS of log_portfolio_total on ``x_vars`` with pair_id/year ids."""
    gls = PanelGLS()
    gls.fit(sub_bp['log_portfolio_total'].values,
            sub_bp[x_vars].values,
            sub_bp['pair_id'].values,
            sub_bp['year'].values)
    return gls


def _run_bilateral(df):
    """Run Model 7: bilateral puzzle test on the gravity bilateral panel.

    ``growth_diff`` is destination (``iso_d``) growth minus origin (``iso_o``)
    growth, looked up from ``df``. M7a regresses ``log_portfolio_total`` on
    gravity covariates plus ``growth_diff``; M7b adds ``dZ_1``..``dZ_3``.
    Skipped when the bilateral panel file is absent or samples are too small.

    Returns:
        list of result dicts (possibly empty).
    """
    print("\n--- Model 7: Bilateral Puzzle Test ---")
    bp_path = ROOT_DIR / "gravity_bilateral" / "data" / "processed" / "bilateral_panel.csv"
    if not bp_path.exists():
        print(f"  M7: {bp_path} not found, skipping")
        return []

    bp = pd.read_csv(bp_path)
    growth_map = df.set_index(['iso3', 'year'])['rgdp_growth'].to_dict()
    bp['growth_j'] = _lookup_growth(growth_map, bp['iso_d'], bp['year'])
    bp['growth_i'] = _lookup_growth(growth_map, bp['iso_o'], bp['year'])
    bp['growth_diff'] = bp['growth_j'] - bp['growth_i']

    required = ['log_portfolio_total', 'pair_id', 'year']
    bilat_vars = ['log_dist', 'contiguity', 'common_lang_official',
                  'log_gdp_product', 'growth_diff']
    sub_bp = bp.dropna(subset=bilat_vars + required)
    if len(sub_bp) <= 100:
        print(f"  M7: insufficient obs ({len(sub_bp)}), skipping")
        return []

    results = []

    # Without demographics
    gls_b1 = _fit_bilateral(sub_bp, bilat_vars)
    print(f"\n  M7a: Bilateral puzzle (no demo) N={gls_b1.n_obs}, R²={gls_b1.r_squared:.4f}")
    idx_gd = bilat_vars.index('growth_diff')
    print(f"    growth_diff: {gls_b1.beta[idx_gd]:.4f} (p={gls_b1.pvalues[idx_gd]:.4f})")
    results.append({
        'model': 'M7a: Bilateral (no demo)',
        'n_obs': gls_b1.n_obs, 'r_squared': gls_b1.r_squared,
        'growth_diff_coef': gls_b1.beta[idx_gd],
        'growth_diff_p': gls_b1.pvalues[idx_gd],
    })

    # With demographics
    dz_vars = ['dZ_1', 'dZ_2', 'dZ_3']
    missing = [v for v in dz_vars if v not in bp.columns]
    if missing:
        print(f"  M7b: missing columns {missing}, skipping")
        return results
    bilat_vars2 = bilat_vars + dz_vars
    sub_bp2 = bp.dropna(subset=bilat_vars2 + required)
    if len(sub_bp2) <= 100:
        print(f"  M7b: insufficient obs ({len(sub_bp2)}), skipping")
        return results

    gls_b2 = _fit_bilateral(sub_bp2, bilat_vars2)
    print(f"\n  M7b: Bilateral puzzle (+ demo) N={gls_b2.n_obs}, R²={gls_b2.r_squared:.4f}")
    idx_gd2 = bilat_vars2.index('growth_diff')
    print(f"    growth_diff: {gls_b2.beta[idx_gd2]:.4f} (p={gls_b2.pvalues[idx_gd2]:.4f})")
    res_b2 = {
        'model': 'M7b: Bilateral (+ demo)',
        'n_obs': gls_b2.n_obs, 'r_squared': gls_b2.r_squared,
        'growth_diff_coef': gls_b2.beta[idx_gd2],
        'growth_diff_p': gls_b2.pvalues[idx_gd2],
    }
    for v in dz_vars:
        idx = bilat_vars2.index(v)
        print(f"    {v}: {gls_b2.beta[idx]:.4f} (p={gls_b2.pvalues[idx]:.4f})")
        # Store dZ coefficients so the markdown table's Z₁ column is filled.
        res_b2[f'{v}_coef'] = gls_b2.beta[idx]
        res_b2[f'{v}_p'] = gls_b2.pvalues[idx]
    results.append(res_b2)
    return results


def _first_present(result, keys):
    """Return result[k] for the first k in ``keys`` present in result, else ''."""
    return next((result[k] for k in keys if k in result), '')


def _fmt4(value):
    """Format a (Python or NumPy) float to 4 decimals; anything else becomes ''."""
    return f"{value:.4f}" if isinstance(value, (float, np.floating)) else ''


def _write_markdown(results, path):
    """Write the formatted Table 2 markdown summary of ``results`` to ``path``."""
    growth_keys = ('rgdp_growth', 'growth_resid', 'growth_diff')
    z1_keys = ('Z_1', 'dZ_1')
    lines = [
        "# Table 2: Allocation Puzzle — Does Demographics Resolve It?\n\n",
        "| Model | N | R² | Growth β | Growth p | Z₁ β | Z₁ p |\n",
        "|-------|---|-----|----------|----------|-------|------|\n",
    ]
    for r in results:
        gb = _fmt4(_first_present(r, [f'{k}_coef' for k in growth_keys]))
        gp = _fmt4(_first_present(r, [f'{k}_p' for k in growth_keys]))
        zb = _fmt4(_first_present(r, [f'{k}_coef' for k in z1_keys]))
        zp = _fmt4(_first_present(r, [f'{k}_p' for k in z1_keys]))
        lines.append(f"| {r['model']} | {r['n_obs']} | {r.get('r_squared', 0):.4f} "
                     f"| {gb} | {gp} | {zb} | {zp} |\n")
    lines.append("\n")
    # Explicit UTF-8: the table contains β, ², ₁ and —, which the platform
    # default encoding (e.g. cp1252) may be unable to encode.
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def main():
    """Run every Phase 2 allocation-puzzle model and write the result tables.

    Reads ``DATA/deepening_panel.csv`` and, if present, the gravity_bilateral
    ``bilateral_panel.csv``. Writes ``allocation_puzzle_results.csv`` and
    ``allocation_puzzle.md`` into ``OUT_TABLES``.
    """
    print("=" * 70)
    print("PHASE 2: ALLOCATION PUZZLE")
    print("=" * 70)

    df = pd.read_csv(DATA / "deepening_panel.csv")
    print(f"Panel: {len(df)} obs, {df['iso3'].nunique()} countries")

    controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'log_rel_opw']

    results = _run_country_models(df, controls)
    r = _run_growth_residual(df, controls)
    if r is not None:
        results.append(r)
    results.extend(_run_r2_decomposition(df))
    results.extend(_run_bilateral(df))

    # ── Save results ──
    print("\n--- Saving results ---")
    pd.DataFrame(results).to_csv(OUT_TABLES / "allocation_puzzle_results.csv",
                                 index=False)
    _write_markdown(results, OUT_TABLES / "allocation_puzzle.md")

    print("\nPhase 2 complete.")


if __name__ == '__main__':
    # Run the Phase 2 pipeline only when executed as a script, not on import.
    main()
